import math
from itertools import combinations
from typing import List, Tuple, Dict

import numpy as np


def compute_orientation(seg):
    """Return the direction of the chord from ``seg[0]`` to ``seg[-1]``.

    The result is ``math.atan2(dy, dx)`` in radians, in ``[-pi, pi]``. Only the
    first two coordinates of each point are used.
    """
    head = seg[0]
    tail = seg[-1]
    dx = tail[0] - head[0]
    dy = tail[1] - head[1]
    return math.atan2(dy, dx)


def compute_average_curvature(seg):
    """Return the mean turning angle (radians) along the polyline ``seg``.

    ``seg`` must support ``.shape`` and row slicing (e.g. a 2-D NumPy array of
    points). For every interior point, the angle between the incoming and the
    outgoing edge vector is measured. Vertices touching a zero-length edge
    (repeated points) are skipped. Returns ``0.0`` when ``seg`` has fewer than
    three points or when no vertex yields an angle.
    """
    if seg.shape[0] < 3:
        return 0.0
    turns = []
    # Walk every (previous, current, next) triple of consecutive points.
    for before, here, after in zip(seg[:-2], seg[1:-1], seg[2:]):
        incoming = here - before
        outgoing = after - here
        len_in = np.linalg.norm(incoming)
        len_out = np.linalg.norm(outgoing)
        if len_in == 0 or len_out == 0:
            continue  # undefined direction at a duplicated point
        cos_turn = np.dot(incoming, outgoing) / (len_in * len_out)
        # Clamp rounding overshoot so acos stays in its domain.
        cos_turn = max(-1.0, min(1.0, cos_turn))
        turns.append(math.acos(cos_turn))  # acos is already non-negative
    if not turns:
        return 0.0
    return float(np.mean(turns))


def min_endpoint_distance(seg_i, seg_j):
    """Return the smallest Euclidean distance between an end of ``seg_i`` and an end of ``seg_j``.

    The "ends" of a sequence are its first and last points. All four
    end-to-end pairs are compared, and the minimum is returned as a float.
    """
    heads_tails_i = np.stack([seg_i[0], seg_i[-1]])
    heads_tails_j = np.stack([seg_j[0], seg_j[-1]])
    # Broadcast to a (2, 2, dim) grid of differences between every end pair.
    offsets = heads_tails_i[:, np.newaxis, :] - heads_tails_j[np.newaxis, :, :]
    pair_dists = np.linalg.norm(offsets, axis=-1)
    return float(np.min(pair_dists))


def angle_diff(a, b):
    """Return the unsigned angular distance between angles ``a`` and ``b`` (radians).

    The difference is wrapped into ``[0, 2*pi)`` and then folded, so the result
    lies in ``[0, pi]``.
    """
    full_turn = 2 * math.pi
    wrapped = abs(a - b) % full_turn
    # Going the other way around the circle may be shorter.
    return min(wrapped, full_turn - wrapped)


def connectivity(seq1, seq2, alpha, beta, gamma, lambda_, mu):
    """Score how well two point sequences connect.

    The score is::

        alpha * exp(-lambda_ * d_min) + beta * cos(dtheta) + gamma * exp(-mu * dk)

    where ``d_min`` is the smallest endpoint-to-endpoint distance, ``dtheta``
    is the folded angle between the two chord orientations, and ``dk`` is the
    absolute difference of their mean turning angles. Higher means better
    connected for positive weights.
    """
    pts_a = np.array(seq1)
    pts_b = np.array(seq2)
    proximity = math.exp(-lambda_ * min_endpoint_distance(pts_a, pts_b))
    alignment = math.cos(angle_diff(compute_orientation(pts_a), compute_orientation(pts_b)))
    curvature_gap = abs(compute_average_curvature(pts_a) - compute_average_curvature(pts_b))
    smoothness = math.exp(-mu * curvature_gap)
    return alpha * proximity + beta * alignment + gamma * smoothness


def stroke_match(routes, meeti, alpha, beta, gamma, lambda_, mu, min_len=1, end_len=10):
    """Register stroke ends per junction and greedily pair them by connectivity.

    Args:
        routes: mapping ``c -> [(st, mc), ...]``. Each ``st`` is a sliceable
            point sequence. Its head end is registered at junction ``meeti[c]``.
            Its tail end is registered at ``meeti[mc]`` unless ``mc == (-1, -1)``.
        meeti: mapping from route keys / ``mc`` values to junction ids.
        alpha, beta, gamma, lambda_, mu: weights forwarded to ``connectivity``.
        min_len: strokes with ``len(st) <= min_len`` are dropped (default 1).
        end_len: number of points sampled from a stroke end for scoring
            (default 10).

    Returns:
        ``(meet_strokes, met_strokes, strokes_new, strokes_cvt, meet_cnt)``:

        - ``meet_strokes[m]`` is a list of ``[stroke_index, end]`` entries;
          ``end`` is 1 for a head end and -1 for a tail end.
        - ``met_strokes`` holds the kept strokes (the original objects).
        - ``strokes_new`` maps each kept stroke index to itself.
        - ``strokes_cvt`` maps each kept stroke index to 1.
        - ``meet_cnt[m]`` lists ``[a, b]`` pairs of positions into
          ``meet_strokes[m]``. Pairs are chosen greedily by descending
          connectivity, and each position is used at most once.
    """
    meet_strokes = {m: [] for m in meeti.values()}
    met_strokes, strokes_new, strokes_cvt = [], {}, {}
    for c, sts in routes.items():
        for st, mc in sts:
            if len(st) <= min_len:
                continue  # too short to take part in matching
            idx = len(met_strokes)
            met_strokes.append(st)
            strokes_new[idx] = idx
            strokes_cvt[idx] = 1
            meet_strokes[meeti[c]].append([idx, 1])
            if mc != (-1, -1):  # (-1, -1) marks a tail that is not at a junction
                meet_strokes[meeti[mc]].append([idx, -1])

    def end_segment(stroke_idx, end):
        """Return up to ``end_len`` points of the stroke, starting at the given end."""
        st = met_strokes[stroke_idx]
        return st[:end_len] if end == 1 else st[:-end_len - 1:-1]

    meet_cnt = {}
    for m, sts in meet_strokes.items():
        segments = [end_segment(si, end) for si, end in sts]
        # NOTE(review): every segment starts at its registered end and points
        # away from it. connectivity's ``beta * cos(dtheta)`` term therefore
        # peaks when two strokes leave in the same direction (a hairpin once
        # merged). Confirm that the sign of ``beta`` matches this intent.
        scored = [
            (a, b, connectivity(segments[a], segments[b], alpha, beta, gamma, lambda_, mu))
            for a, b in combinations(range(len(segments)), 2)
        ]
        scored.sort(key=lambda item: item[2], reverse=True)
        used = set()
        matches: List[List[int]] = []
        for a, b, _score in scored:
            if a not in used and b not in used:
                matches.append([a, b])
                used.update((a, b))
        meet_cnt[m] = matches

    return meet_strokes, met_strokes, strokes_new, strokes_cvt, meet_cnt


def stroke_connect(meet_strokes, met_strokes, strokes_new, strokes_cvt, meet_cnt, sim_strokes):
    """Merge matched stroke ends into longer strokes.

    Args:
        meet_strokes: ``m -> [[stroke_index, end], ...]``; ``end`` is 1 for the
            stroke's head, -1 for its tail.
        met_strokes: list of strokes (point lists). Updated in place: merged
            groups are stored at the index of the surviving group.
        strokes_new: ``stroke_index -> group index`` into ``met_strokes``.
            Updated in place as groups are merged.
        strokes_cvt: ``stroke_index -> +1/-1`` orientation of the stroke within
            its current group. Updated in place whenever a group is reversed.
        meet_cnt: ``m -> [[a, b], ...]`` pairs of positions into
            ``meet_strokes[m]`` whose ends should be joined.
        sim_strokes: extra strokes appended unchanged to the result.

    Returns:
        A new list containing the surviving merged strokes (in
        ``met_strokes`` order) followed by ``sim_strokes``.

    The original point lists are not mutated. Every merge builds a new list,
    so strokes aliased from the caller's ``routes`` stay intact.
    """
    removed = set()

    def _absorb(src, dst, reverse_src):
        """Relabel group ``src`` as ``dst``. Flip orientations if ``src`` was reversed."""
        for si, group in strokes_new.items():
            if group == src:
                strokes_new[si] = dst
                if reverse_src:
                    # Every member of a reversed group changes orientation,
                    # not just the stroke whose end was matched.
                    strokes_cvt[si] *= -1
        removed.add(src)

    for m, sts in meet_strokes.items():
        for cnt in meet_cnt[m]:
            si1 = sts[cnt[0]][0]
            si2 = sts[cnt[1]][0]
            sni1 = strokes_new[si1]
            sni2 = strokes_new[si2]
            if sni1 == sni2:
                # Both free ends of one chain meet here. Concatenating the
                # chain with itself would duplicate every point, so leave it.
                continue
            # +1: this end is at the head of its merged group; -1: at the tail.
            dir1 = sts[cnt[0]][1] * strokes_cvt[si1]
            dir2 = sts[cnt[1]][1] * strokes_cvt[si2]
            if dir1 == 1 and dir2 == 1:
                met_strokes[sni1] = met_strokes[sni2][::-1] + met_strokes[sni1]
                _absorb(sni2, sni1, True)
            elif dir1 == 1 and dir2 == -1:
                met_strokes[sni2] = met_strokes[sni2] + met_strokes[sni1]
                _absorb(sni1, sni2, False)
            elif dir1 == -1 and dir2 == -1:
                met_strokes[sni1] = met_strokes[sni1] + met_strokes[sni2][::-1]
                _absorb(sni2, sni1, True)
            elif dir1 == -1 and dir2 == 1:
                met_strokes[sni1] = met_strokes[sni1] + met_strokes[sni2]
                _absorb(sni2, sni1, False)

    strokes = [st for idx, st in enumerate(met_strokes) if idx not in removed]
    strokes += sim_strokes
    return strokes
